# Script that loads the tritonbench operators and generates their metadata YAML files
import argparse
import logging
import os
import sys

import yaml

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


from os.path import abspath, exists

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))


def setup_tritonbench_cwd():
    """Switch the working directory to the tritonbench checkout.

    The candidate locations are probed in order, relative to the current
    working directory, and the first one that exists is used: the process
    changes into it and its absolute path is appended to ``sys.path``.
    Note that ``"."`` is probed first, so the second candidate is only
    reached when the current directory itself cannot be found.

    Returns:
        The absolute path of the working directory at the time of the call.
    """
    starting_dir = abspath(os.getcwd())

    candidates = (".", "../../../tritonbench")
    # First existing candidate; when none exists, keep the last candidate so
    # the existence check below decides whether anything happens at all.
    chosen = next((path for path in candidates if exists(path)), candidates[-1])

    if not exists(chosen):
        return starting_dir

    resolved = abspath(chosen)
    os.chdir(resolved)
    sys.path.append(resolved)
    return starting_dir


setup_tritonbench_cwd()

from tritonbench.operators import list_operators, load_opbench_by_name

# operators that are not supported by tritonbench-oss
SKIP_OPERATORS = ["decoding_attention"]
# operators whose workload is too small for a meaningful tflops measurement;
# they are skipped even if they report the tflops metric
TFLOPS_SKIP_OPERATORS = ["low_mem_dropout"]

HEADER_TEMPLATE = """# =================================================================
# This file is generated by benchmarks/gen_metadata/run.py
# {DESCRIPTION}
# =================================================================
"""

# Module-level accumulators populated by run(). METADATA_MAPPING below holds
# references to these very objects, so they must be mutated in place rather
# than rebound to new lists/dicts.
# Names of operators whose has_bwd() check is truthy.
BACKWARD_OPERATORS = []
# Names of operators that report the "tflops" metric and are not listed in
# TFLOPS_SKIP_OPERATORS.
TFLOPS_OPERATORS = []
# Operator name -> the truthy value returned by has_baseline().
BASELINE_OPERATORS = {}
# Operator name -> the benchmark's DEFAULT_PRECISION attribute.
DTYPE_OPERATORS = {}

# Metadata kind -> (data object to dump, output file name, description that is
# substituted into HEADER_TEMPLATE). run() writes one YAML file per entry,
# iterating in the order the entries are declared here.
METADATA_MAPPING = {
    "backward": (
        BACKWARD_OPERATORS,
        "backward_operators.yaml",
        "List of operators that support backward pass",
    ),
    "baseline": (
        BASELINE_OPERATORS,
        "baseline_operators.yaml",
        "List of operators that have baseline benchmark for accuracy and speedup.",
    ),
    "tflops": (
        TFLOPS_OPERATORS,
        "tflops_operators.yaml",
        "List of operators that support tflops metric",
    ),
    "dtype": (
        DTYPE_OPERATORS,
        "dtype_operators.yaml",
        "List of operators and their default dtype",
    ),
}


def run(args: argparse.Namespace):
    """Collect per-operator metadata and write it out as YAML files.

    Every operator returned by ``list_operators()`` that is not listed in
    ``SKIP_OPERATORS`` is loaded and inspected. The results are stored in the
    module-level containers (``DTYPE_OPERATORS``, ``BASELINE_OPERATORS``,
    ``TFLOPS_OPERATORS``, ``BACKWARD_OPERATORS``), and each entry of
    ``METADATA_MAPPING`` is then dumped to its own file, prefixed with
    ``HEADER_TEMPLATE``.

    Args:
        args: Parsed command-line arguments; only ``args.output`` (the
            destination directory) is read.

    Raises:
        OSError: If the output directory cannot be created or a metadata
            file cannot be written.
    """
    # Create the destination up front: a custom --output path that does not
    # exist yet would otherwise make open() fail, and an unusable path is
    # reported before any operator is loaded.
    os.makedirs(args.output, exist_ok=True)

    # The containers are module-level state shared with METADATA_MAPPING.
    # Empty them in place (keeping object identity) so that calling run()
    # again in the same process does not duplicate entries.
    BACKWARD_OPERATORS.clear()
    TFLOPS_OPERATORS.clear()
    BASELINE_OPERATORS.clear()
    DTYPE_OPERATORS.clear()

    for op in list_operators():
        if op in SKIP_OPERATORS:
            continue
        logger.info("Processing operator %s ...", op)
        op_bench = load_opbench_by_name(op_name=op)
        DTYPE_OPERATORS[op] = op_bench.DEFAULT_PRECISION
        # has_baseline() returns the value to record; falsy means no baseline.
        if baseline := op_bench.has_baseline():
            BASELINE_OPERATORS[op] = baseline
        if op_bench.has_metric("tflops") and op not in TFLOPS_SKIP_OPERATORS:
            TFLOPS_OPERATORS.append(op)
        if op_bench.has_bwd():
            BACKWARD_OPERATORS.append(op)

    for kind, (obj, fname, description) in METADATA_MAPPING.items():
        output_file = os.path.join(args.output, fname)
        with open(output_file, "w", encoding="utf-8") as out:
            logger.info("Writing %s metadata to %s", kind, output_file)
            # sort_keys=False keeps the operators in discovery order.
            yaml_str = yaml.safe_dump(obj, sort_keys=False)
            out.write(HEADER_TEMPLATE.format(DESCRIPTION=description) + yaml_str)


def main() -> None:
    """Parse the command-line options and generate the metadata files.

    The only option is ``--output``, the directory that receives the YAML
    files; it defaults to the ``metadata`` directory next to this script.
    """
    default_output = os.path.join(CURRENT_DIR, "metadata")
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--output",
        type=str,
        default=default_output,
        help="generate metadata yaml files to the specific directory",
    )
    run(cli.parse_args())


if __name__ == "__main__":
    # Do not add code here, it won't be run. Add them to the function called below.
    main()  # pragma: no cover
